{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnx\n",
    "import onnxruntime\n",
    "import torch.nn as nn\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNN(nn.Module):\n",
    "    \"\"\"Small convolutional classifier: 3-channel 32x32 images in, 10 logits out.\n",
    "\n",
    "    Three conv(3x3, padding=1) -> ReLU -> 2x2 max-pool stages halve the spatial\n",
    "    size each time (32 -> 16 -> 8 -> 4) while the channels grow 3 -> 16 -> 32 -> 64.\n",
    "    The flattened 64 * 4 * 4 features then pass through two linear layers, with\n",
    "    dropout applied before each of them.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Attribute names and order are kept stable: they define the\n",
    "        # state_dict keys that a saved checkpoint is matched against.\n",
    "        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)\n",
    "        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)\n",
    "        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.linear1 = nn.Linear(64 * 4 * 4, 512)\n",
    "        self.linear2 = nn.Linear(512, 10)\n",
    "        self.dropout = nn.Dropout(p=0.3)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Compute class logits for a batch of images.\n",
    "\n",
    "        Args:\n",
    "            x: float tensor of shape (batch, 3, 32, 32).\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape (batch, 10) holding unnormalised class scores.\n",
    "\n",
    "        Raises:\n",
    "            RuntimeError: if the spatial size of ``x`` does not reduce to 4x4,\n",
    "                because the flattened features no longer fit ``linear1``.\n",
    "        \"\"\"\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = self.pool(F.relu(self.conv3(x)))\n",
    "        # Flatten everything except the batch axis. Unlike view(-1, 64 * 4 * 4),\n",
    "        # this never folds surplus spatial elements into the batch dimension, so\n",
    "        # a wrongly sized input fails loudly in linear1 instead of silently\n",
    "        # returning logits for a different number of samples.\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.dropout(x)\n",
    "        x = F.relu(self.linear1(x))\n",
    "        x = self.dropout(x)\n",
    "        x = self.linear2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the network; at this point it only holds its default initial weights.\n",
    "model = CNN()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the saved weights from cifar10.pth into the model.\n",
    "# map_location=\"cpu\" remaps any tensors that were saved from a GPU, so the\n",
    "# checkpoint also loads on machines without CUDA.\n",
    "# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.\n",
    "model.load_state_dict(torch.load(\"cifar10.pth\", map_location=\"cpu\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CNN(\n",
       "  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  (linear1): Linear(in_features=1024, out_features=512, bias=True)\n",
       "  (linear2): Linear(in_features=512, out_features=10, bias=True)\n",
       "  (dropout): Dropout(p=0.3, inplace=False)\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Switch to inference mode so the dropout layers are inactive and the forward\n",
    "# pass is deterministic; eval() returns the module, hence the printed summary.\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random dummy input: one sample, 3 channels, 32x32 pixels. It is used both to\n",
    "# run the PyTorch model and as the example input for the ONNX export below.\n",
    "# NOTE(review): no seed is set, so the values differ on every fresh run.\n",
    "x = torch.randn(1, 3, 32, 32, requires_grad=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference output from PyTorch; compared against ONNX Runtime at the end.\n",
    "model_out = model(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 3.5831, 10.1674, -4.0835, -4.4179, -4.0980, -5.8097, -2.0841, -1.6725,\n",
       "          1.0503,  8.2375]], grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the raw logits: one row for the single sample, ten class scores.\n",
    "model_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the model to cifar.onnx using x as the example input.\n",
    "# The trained weights are stored in the file (export_params=True), and axis 0\n",
    "# of both the input and the output is declared dynamic under the name\n",
    "# 'batch_size'.\n",
    "torch.onnx.export(\n",
    "    model,\n",
    "    x,\n",
    "    \"cifar.onnx\",\n",
    "    export_params=True,\n",
    "    opset_version=10,\n",
    "    do_constant_folding=True,\n",
    "    input_names=['input'],\n",
    "    output_names=['output'],\n",
    "    dynamic_axes={\n",
    "        'input': {0: 'batch_size'},\n",
    "        'output': {0: 'batch_size'},\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the exported file back so it can be validated.\n",
    "onnx_model = onnx.load(\"cifar.onnx\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validate the exported graph; the checker raises if the model is malformed\n",
    "# and produces no output when it is valid.\n",
    "onnx.checker.check_model(onnx_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the ONNX Runtime session for the exported model.\n",
    "# The CPU execution provider is pinned explicitly: ONNX Runtime 1.9+ builds\n",
    "# that ship additional providers (e.g. CUDA) raise when `providers` is left\n",
    "# unset.\n",
    "ort_session = onnxruntime.InferenceSession(\n",
    "    \"cifar.onnx\",\n",
    "    providers=[\"CPUExecutionProvider\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_numpy(tensor):\n",
    "    \"\"\"Return the tensor's data as a NumPy array on the CPU.\n",
    "\n",
    "    Tensors that track gradients are detached first, because NumPy conversion\n",
    "    is not allowed on a tensor that requires grad.\n",
    "    \"\"\"\n",
    "    if tensor.requires_grad:\n",
    "        tensor = tensor.detach()\n",
    "    return tensor.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feed the same dummy input to ONNX Runtime, keyed by the name of the\n",
    "# session's first input.\n",
    "input_name = ort_session.get_inputs()[0].name\n",
    "ort_inputs = {input_name: to_numpy(x)}\n",
    "# None is passed as the first argument of run(), as in the original call.\n",
    "ort_outs = ort_session.run(None, ort_inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parity check: the first ONNX Runtime output must match the PyTorch logits\n",
    "# within rtol=1e-3 / atol=1e-5. assert_allclose raises AssertionError on a\n",
    "# mismatch and is silent on success.\n",
    "np.testing.assert_allclose(to_numpy(model_out), ort_outs[0], rtol=1e-03, atol=1e-05)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
